{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FRWGyaT2Y25j"
   },
   "source": [
    "# Catalyst segmentation tutorial\n",
    "\n",
    "Authors: [Roman Tezikov](https://github.com/TezRomacH), [Dmitry Bleklov](https://github.com/Bekovmi), [Sergey Kolesnikov](https://github.com/Scitator)\n",
    "\n",
    "[![Catalyst logo](https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master/pics/catalyst_logo.png)](https://github.com/catalyst-team/catalyst)\n",
    "\n",
    "### Colab setup\n",
    "\n",
    "First of all, do not forget to change the runtime type to GPU. <br/>\n",
    "To do so click `Runtime` -> `Change runtime type` -> Select `\"Python 3\"` and `\"GPU\"` -> click `Save`. <br/>\n",
    "After that you can click `Runtime` -> `Run` all and watch the tutorial.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hHJAs8U5Y25m"
   },
   "source": [
    "\n",
    "## Requirements\n",
    "\n",
    "Download and install the latest version of catalyst and other libraries required for this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this variable will be used in `runner.train` and by default we disable FP16 mode\n",
    "is_fp16_used = False\n",
    "is_alchemy_used = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
    "colab": {},
    "colab_type": "code",
    "id": "xCoyLtaeY25m",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# for augmentations\n",
    "!pip install albumentations==0.4.3\n",
    "\n",
    "# for pretrained segmentation models fo PyTorch\n",
    "!pip install segmentation-models-pytorch==0.1.0\n",
    "\n",
    "# for TTA\n",
    "!pip install ttach==0.0.2\n",
    "\n",
    "################\n",
    "# Catalyst itself\n",
    "!pip install -U catalyst\n",
    "# For specific version of catalyst, uncomment:\n",
    "# ! pip install git+http://github.com/catalyst-team/catalyst.git@{master/commit_hash}\n",
    "################\n",
    "\n",
    "# for tensorboard\n",
    "!pip install tensorflow\n",
    "\n",
    "# for alchemy experiment logging integration, uncomment this 2 lines below\n",
    "# !pip install -U alchemy-catalyst\n",
    "# is_alchemy_used = True\n",
    "\n",
    "# if Your machine support Apex FP16, uncomment this 3 lines below\n",
    "# !git clone https://github.com/NVIDIA/apex\n",
    "# !pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./apex\n",
    "# is_fp16_used = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MncoA-G0Y25p"
   },
   "source": [
    "### Colab extras – Plotly\n",
    "\n",
    "To intergate visualization library `plotly` to colab, run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import IPython\n",
    "\n",
    "def configure_plotly_browser_state():\n",
    "    display(IPython.core.display.HTML('''\n",
    "        <script src=\"/static/components/requirejs/require.js\"></script>\n",
    "        <script>\n",
    "          requirejs.config({\n",
    "            paths: {\n",
    "              base: '/static/base',\n",
    "              plotly: 'https://cdn.plot.ly/plotly-latest.min.js?noext',\n",
    "            },\n",
    "          });\n",
    "        </script>\n",
    "        '''))\n",
    "\n",
    "\n",
    "IPython.get_ipython().events.register('pre_run_cell', configure_plotly_browser_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up GPUs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8H65wGVbY25q",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from typing import Callable, List, Tuple\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import catalyst\n",
    "\n",
    "from catalyst.dl import utils\n",
    "\n",
    "print(f\"torch: {torch.__version__}, catalyst: {catalyst.__version__}\")\n",
    "\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"  # \"\" - CPU, \"0\" - 1 GPU, \"0,1\" - MultiGPU\n",
    "\n",
    "SEED = 42\n",
    "utils.set_global_seed(SEED)\n",
    "utils.prepare_cudnn(deterministic=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Reproducibility\n",
    "\n",
    "[![Alchemy logo](https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master/pics/alchemy_logo.png)](https://github.com/catalyst-team/alchemy)\n",
    "\n",
    "To make your research more reproducible and easy to monitor, Catalyst has an integration with [Alchemy](https://alchemy.host) – experiment tracking tool for deep learning.\n",
    "\n",
    "To use monitoring, goto [Alchemy](https://alchemy.host/) and get your personal token."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# for alchemy experiment logging integration, uncomment this 2 lines below\n",
    "# !pip install -U alchemy-catalyst\n",
    "# is_alchemy_used = True"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "if is_alchemy_used:\n",
    "    monitoring_params = {\n",
    "        \"token\": None, # insert your personal token here\n",
    "        \"project\": \"segmentation_example\",\n",
    "        \"group\": \"first_trials\",\n",
    "        \"experiment\": \"first_experiment\",\n",
    "    }\n",
    "    assert monitoring_params[\"token\"] is not None\n",
    "else:\n",
    "    monitoring_params = None"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "nWseoJqWY25z"
   },
   "source": [
    "## Dataset\n",
    "\n",
    "As a dataset we will take Carvana - binary segmentation for the \"car\" class."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "34H0dXBYZepr"
   },
   "source": [
 If you are on MacOS">
    "> If you are on macOS and you don’t have `wget`, you can install it with: `brew install wget`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After Catalyst installation, `download-gdrive` function become available to download objects from Google Drive.\n",
    "We use it to download datasets.\n",
    "\n",
    "usage: `download-gdrive {FILE_ID} {FILENAME}`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8H8tDrZ6Y250"
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "\n",
    "download-gdrive 1iYaNijLmzsrMlAdMoUEhhJuo-5bkeAuj segmentation_data.zip\n",
    "extract-archive segmentation_data.zip &>/dev/null"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "g4Vqm9FzY254"
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "ROOT = Path(\"segmentation_data/\")\n",
    "\n",
    "train_image_path = ROOT / \"train\"\n",
    "train_mask_path = ROOT / \"train_masks\"\n",
    "test_image_path = ROOT / \"test\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vL0sd6__M_98"
   },
   "source": [
    "Collect images and masks into variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "wU9IPpyAfhOy"
   },
   "outputs": [],
   "source": [
    "ALL_IMAGES = sorted(train_image_path.glob(\"*.jpg\"))\n",
    "len(ALL_IMAGES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ElvJQrlOf4Gm"
   },
   "outputs": [],
   "source": [
    "ALL_MASKS = sorted(train_mask_path.glob(\"*.gif\"))\n",
    "len(ALL_MASKS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "6mcVBd_WjI43"
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from skimage.io import imread as gif_imread\n",
    "from catalyst import utils\n",
    "\n",
    "\n",
    "def show_examples(name: str, image: np.ndarray, mask: np.ndarray):\n",
    "    plt.figure(figsize=(10, 14))\n",
    "    plt.subplot(1, 2, 1)\n",
    "    plt.imshow(image)\n",
    "    plt.title(f\"Image: {name}\")\n",
    "\n",
    "    plt.subplot(1, 2, 2)\n",
    "    plt.imshow(mask)\n",
    "    plt.title(f\"Mask: {name}\")\n",
    "\n",
    "\n",
    "def show(index: int, images: List[Path], masks: List[Path], transforms=None) -> None:\n",
    "    image_path = images[index]\n",
    "    name = image_path.name\n",
    "\n",
    "    image = utils.imread(image_path)\n",
    "    mask = gif_imread(masks[index])\n",
    "\n",
    "    if transforms is not None:\n",
    "        temp = transforms(image=image, mask=mask)\n",
    "        image = temp[\"image\"]\n",
    "        mask = temp[\"mask\"]\n",
    "\n",
    "    show_examples(name, image, mask)\n",
    "\n",
    "def show_random(images: List[Path], masks: List[Path], transforms=None) -> None:\n",
    "    length = len(images)\n",
    "    index = random.randint(0, length - 1)\n",
    "    show(index, images, masks, transforms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "f0zVSTYtk8hf"
   },
   "source": [
    "You can restart the cell below to see more examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "r0EZVF1pk3tC"
   },
   "outputs": [],
   "source": [
    "show_random(ALL_IMAGES, ALL_MASKS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JpEVELsNNMTO"
   },
   "source": [
    "The dataset below reads images and masks and optionally applies augmentation to them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Grrv0FqpY25-"
   },
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "\n",
    "class SegmentationDataset(Dataset):\n",
    "    def __init__(\n",
    "        self,\n",
    "        images: List[Path],\n",
    "        masks: List[Path] = None,\n",
    "        transforms=None\n",
    "    ) -> None:\n",
    "        self.images = images\n",
    "        self.masks = masks\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        return len(self.images)\n",
    "\n",
    "    def __getitem__(self, idx: int) -> dict:\n",
    "        image_path = self.images[idx]\n",
    "        image = utils.imread(image_path)\n",
    "        \n",
    "        result = {\"image\": image}\n",
    "        \n",
    "        if self.masks is not None:\n",
    "            mask = gif_imread(self.masks[idx])\n",
    "            result[\"mask\"] = mask\n",
    "        \n",
    "        if self.transforms is not None:\n",
    "            result = self.transforms(**result)\n",
    "        \n",
    "        result[\"filename\"] = image_path.name\n",
    "\n",
    "        return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "DsI3ZS2asQqg"
   },
   "source": [
    "### Augmentations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RF0wtzsjNZQ5"
   },
   "source": [
    "[![Albumentation logo](https://albumentations.readthedocs.io/en/latest/_static/logo.png)](https://github.com/albu/albumentations)\n",
    "\n",
    "The [albumentation](https://github.com/albu/albumentations) library works with images and masks at the same time, which is what we need."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "zNdK5P0UY26A"
   },
   "outputs": [],
   "source": [
    "import albumentations as albu\n",
    "from albumentations.pytorch import ToTensor\n",
    "\n",
    "\n",
    "def pre_transforms(image_size=224):\n",
    "    return [albu.Resize(image_size, image_size, p=1)]\n",
    "\n",
    "\n",
    "def hard_transforms():\n",
    "    result = [\n",
    "      albu.RandomRotate90(),\n",
    "      albu.Cutout(),\n",
    "      albu.RandomBrightnessContrast(\n",
    "          brightness_limit=0.2, contrast_limit=0.2, p=0.3\n",
    "      ),\n",
    "      albu.GridDistortion(p=0.3),\n",
    "      albu.HueSaturationValue(p=0.3)\n",
    "    ]\n",
    "\n",
    "    return result\n",
    "  \n",
    "\n",
    "def resize_transforms(image_size=224):\n",
    "    BORDER_CONSTANT = 0\n",
    "    pre_size = int(image_size * 1.5)\n",
    "\n",
    "    random_crop = albu.Compose([\n",
    "      albu.SmallestMaxSize(pre_size, p=1),\n",
    "      albu.RandomCrop(\n",
    "          image_size, image_size, p=1\n",
    "      )\n",
    "\n",
    "    ])\n",
    "\n",
    "    rescale = albu.Compose([albu.Resize(image_size, image_size, p=1)])\n",
    "\n",
    "    random_crop_big = albu.Compose([\n",
    "      albu.LongestMaxSize(pre_size, p=1),\n",
    "      albu.RandomCrop(\n",
    "          image_size, image_size, p=1\n",
    "      )\n",
    "\n",
    "    ])\n",
    "\n",
    "    # Converts the image to a square of size image_size x image_size\n",
    "    result = [\n",
    "      albu.OneOf([\n",
    "          random_crop,\n",
    "          rescale,\n",
    "          random_crop_big\n",
    "      ], p=1)\n",
    "    ]\n",
    "\n",
    "    return result\n",
    "  \n",
    "def post_transforms():\n",
    "    # we use ImageNet image normalization\n",
    "    # and convert it to torch.Tensor\n",
    "    return [albu.Normalize(), ToTensor()]\n",
    "  \n",
    "def compose(transforms_to_compose):\n",
    "    # combine all augmentations into one single pipeline\n",
    "    result = albu.Compose([\n",
    "      item for sublist in transforms_to_compose for item in sublist\n",
    "    ])\n",
    "    return result"
   ]
  },
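  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `compose` helper above takes a list of transform lists and flattens it into one sequence before passing it to `albu.Compose`. The cell below is a minimal sketch of that flattening step alone, in plain Python. The string placeholders and the name `flatten_stages` are illustrative assumptions and are not part of albumentations or Catalyst."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import chain\n",
    "\n",
    "\n",
    "def flatten_stages(stages):\n",
    "    # Same result as the nested comprehension used in `compose`:\n",
    "    # [item for sublist in stages for item in sublist]\n",
    "    return list(chain.from_iterable(stages))\n",
    "\n",
    "\n",
    "# Hypothetical placeholders standing in for albumentations transforms\n",
    "example_stages = [['resize'], ['rotate90', 'cutout'], ['normalize', 'to_tensor']]\n",
    "print(flatten_stages(example_stages))"
   ]
  },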
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "mrRzeFppnPMt"
   },
   "outputs": [],
   "source": [
    "train_transforms = compose([\n",
    "    resize_transforms(), \n",
    "    hard_transforms(), \n",
    "    post_transforms()\n",
    "])\n",
    "valid_transforms = compose([pre_transforms(), post_transforms()])\n",
    "\n",
    "show_transforms = compose([resize_transforms(), hard_transforms()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CTiQd-frN74v"
   },
   "source": [
    "Let's look at the augmented results. <br/>\n",
    "You can restart the cell below to see more examples of augmentations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gGIbhBm1orXt"
   },
   "outputs": [],
   "source": [
    "show_random(ALL_IMAGES, ALL_MASKS, transforms=show_transforms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RvLlx1OvosGX"
   },
   "source": [
    "-------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fGotRWGjyYOw"
   },
   "source": [
    "## Loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WIfyQYAhY26C"
   },
   "outputs": [],
   "source": [
    "import collections\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "def get_loaders(\n",
    "    images: List[Path],\n",
    "    masks: List[Path],\n",
    "    random_state: int,\n",
    "    valid_size: float = 0.2,\n",
    "    batch_size: int = 32,\n",
    "    num_workers: int = 4,\n",
    "    train_transforms_fn = None,\n",
    "    valid_transforms_fn = None,\n",
    ") -> dict:\n",
    "\n",
    "    indices = np.arange(len(images))\n",
    "\n",
    "    # Let's divide the data set into train and valid parts.\n",
    "    train_indices, valid_indices = train_test_split(\n",
    "      indices, test_size=valid_size, random_state=random_state, shuffle=True\n",
    "    )\n",
    "\n",
    "    np_images = np.array(images)\n",
    "    np_masks = np.array(masks)\n",
    "\n",
    "    # Creates our train dataset\n",
    "    train_dataset = SegmentationDataset(\n",
    "      images = np_images[train_indices].tolist(),\n",
    "      masks = np_masks[train_indices].tolist(),\n",
    "      transforms = train_transforms_fn\n",
    "    )\n",
    "\n",
    "    # Creates our valid dataset\n",
    "    valid_dataset = SegmentationDataset(\n",
    "      images = np_images[valid_indices].tolist(),\n",
    "      masks = np_masks[valid_indices].tolist(),\n",
    "      transforms = valid_transforms_fn\n",
    "    )\n",
    "\n",
    "    # Catalyst uses normal torch.data.DataLoader\n",
    "    train_loader = DataLoader(\n",
    "      train_dataset,\n",
    "      batch_size=batch_size,\n",
    "      shuffle=True,\n",
    "      num_workers=num_workers,\n",
    "      drop_last=True,\n",
    "    )\n",
    "\n",
    "    valid_loader = DataLoader(\n",
    "      valid_dataset,\n",
    "      batch_size=batch_size,\n",
    "      shuffle=False,\n",
    "      num_workers=num_workers,\n",
    "      drop_last=True,\n",
    "    )\n",
    "\n",
    "    # And excpect to get an OrderedDict of loaders\n",
    "    loaders = collections.OrderedDict()\n",
    "    loaders[\"train\"] = train_loader\n",
    "    loaders[\"valid\"] = valid_loader\n",
    "\n",
    "    return loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "OG6_pgmkxk_b"
   },
   "outputs": [],
   "source": [
    "if is_fp16_used:\n",
    "    batch_size = 64\n",
    "else:\n",
    "    batch_size = 32\n",
    "\n",
    "print(f\"batch_size: {batch_size}\")\n",
    "\n",
    "loaders = get_loaders(\n",
    "    images=ALL_IMAGES,\n",
    "    masks=ALL_MASKS,\n",
    "    random_state=SEED,\n",
    "    train_transforms_fn=train_transforms,\n",
    "    valid_transforms_fn=valid_transforms,\n",
    "    batch_size=batch_size\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5BF2Vgdly9RR"
   },
   "source": [
    "## Experiment\n",
    "### Model\n",
    "\n",
    "Catalyst has [several segmentation models](https://github.com/catalyst-team/catalyst/blob/master/catalyst/contrib/models/segmentation/__init__.py#L16) (Unet, Linknet, FPN, PSPnet and their versions with pretrain from Resnet).\n",
    "\n",
    "> You can read more about them in [our blog post](https://github.com/catalyst-team/catalyst-info#catalyst-info-1-segmentation-models).\n",
    "\n",
    "But for now let's take the model from [segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch) (SMP for short). The same segmentation architectures have been implemented in this repository, but there are many more pre-trained encoders.\n",
    "\n",
    "[![Segmentation Models logo](https://raw.githubusercontent.com/qubvel/segmentation_models.pytorch/master/pics/logo-small-w300.png)](https://github.com/qubvel/segmentation_models.pytorch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Zm7JsNrczOQG"
   },
   "outputs": [],
   "source": [
    "import segmentation_models_pytorch as smp\n",
    "\n",
    "# We will use Feature Pyramid Network with pre-trained ResNeXt50 backbone\n",
    "model = smp.FPN(encoder_name=\"resnext50_32x4d\", classes=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GXxJFBUkybYs"
   },
   "source": [
    "### Model training\n",
    "\n",
    "We will optimize loss as the sum of IoU, Dice and BCE, specifically this function: $IoU + Dice + 0.8*BCE$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "nhVSEDGbY26G"
   },
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "from catalyst.contrib.nn import DiceLoss, IoULoss\n",
    "\n",
    "# we have multiple criterions\n",
    "criterion = {\n",
    "    \"dice\": DiceLoss(),\n",
    "    \"iou\": IoULoss(),\n",
    "    \"bce\": nn.BCEWithLogitsLoss()\n",
    "}"
   ]
  },
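  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the weighted sum concrete, the cell below sketches $IoU + Dice + 0.8 \\cdot BCE$ on flat Python lists using only the standard library. It assumes a common soft formulation of the Dice and IoU losses computed on sigmoid probabilities; Catalyst's `DiceLoss` and `IoULoss` may differ in smoothing and thresholding details. The names `scalar_sigmoid` and `combined_loss_sketch` are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def scalar_sigmoid(x):\n",
    "    return 1.0 / (1.0 + math.exp(-x))\n",
    "\n",
    "\n",
    "def combined_loss_sketch(logits, targets, bce_weight=0.8, eps=1e-7):\n",
    "    # logits and targets are flat lists of equal length; targets are 0 or 1\n",
    "    probs = [scalar_sigmoid(x) for x in logits]\n",
    "    intersection = sum(p * t for p, t in zip(probs, targets))\n",
    "    total = sum(probs) + sum(targets)\n",
    "    # Assumed soft definitions, not necessarily Catalyst's exact ones\n",
    "    dice_loss = 1.0 - 2.0 * intersection / (total + eps)\n",
    "    iou_loss = 1.0 - intersection / (total - intersection + eps)\n",
    "    bce = -sum(\n",
    "        t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps)\n",
    "        for p, t in zip(probs, targets)\n",
    "    ) / len(targets)\n",
    "    return iou_loss + dice_loss + bce_weight * bce\n",
    "\n",
    "\n",
    "# Confident correct logits give a loss near zero; inverted logits give a large one\n",
    "print(combined_loss_sketch([20.0, -20.0, 20.0], [1, 0, 1]))\n",
    "print(combined_loss_sketch([-20.0, 20.0, -20.0], [1, 0, 1]))"
   ]
  },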
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "IirWWkf8PeXh"
   },
   "outputs": [],
   "source": [
    "from torch import optim\n",
    "\n",
    "from catalyst.contrib.nn import RAdam, Lookahead\n",
    "\n",
    "learning_rate = 0.001\n",
    "encoder_learning_rate = 0.0005\n",
    "\n",
    "# Since we use a pre-trained encoder, we will reduce the learning rate on it.\n",
    "layerwise_params = {\"encoder*\": dict(lr=encoder_learning_rate, weight_decay=0.00003)}\n",
    "\n",
    "# This function removes weight_decay for biases and applies our layerwise_params\n",
    "model_params = utils.process_model_params(model, layerwise_params=layerwise_params)\n",
    "\n",
    "# Catalyst has new SOTA optimizers out of box\n",
    "base_optimizer = RAdam(model_params, lr=learning_rate, weight_decay=0.0003)\n",
    "optimizer = Lookahead(base_optimizer)\n",
    "\n",
    "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3pTmRiOfY26I"
   },
   "outputs": [],
   "source": [
    "if is_alchemy_used:\n",
    "    from catalyst.dl import SupervisedAlchemyRunner as SupervisedRunner\n",
    "else:\n",
    "    from catalyst.dl import SupervisedRunner\n",
    "\n",
    "num_epochs = 3\n",
    "logdir = \"./logs/segmentation\"\n",
    "\n",
    "device = utils.get_device()\n",
    "print(f\"device: {device}\")\n",
    "\n",
    "if is_fp16_used:\n",
    "    fp16_params = dict(opt_level=\"O1\") # params for FP16\n",
    "else:\n",
    "    fp16_params = None\n",
    "\n",
    "print(f\"FP16 params: {fp16_params}\")\n",
    "\n",
    "\n",
    "# by default SupervisedRunner uses \"features\" and \"targets\",\n",
    "# in our case we get \"image\" and \"mask\" keys in dataset __getitem__\n",
    "runner = SupervisedRunner(device=device, input_key=\"image\", input_target_key=\"mask\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "n5mvw0fO1xI5"
   },
   "source": [
    "### Monitoring in tensorboard"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you do not have a Tensorboard opened after you have run the cell below, try running the cell again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext tensorboard\n",
    "%tensorboard --logdir {logdir}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running train-loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "VqBlC5_iY26K"
   },
   "outputs": [],
   "source": [
    "from catalyst.dl.callbacks import DiceCallback, IouCallback, \\\n",
    "  CriterionCallback, CriterionAggregatorCallback\n",
    "\n",
    "runner.train(\n",
    "    model=model,\n",
    "    criterion=criterion,\n",
    "    optimizer=optimizer,\n",
    "    scheduler=scheduler,\n",
    "    \n",
    "    # our dataloaders\n",
    "    loaders=loaders,\n",
    "    \n",
    "    callbacks=[\n",
    "        # Each criterion is calculated separately.\n",
    "        CriterionCallback(\n",
    "            input_key=\"mask\",\n",
    "            prefix=\"loss_dice\",\n",
    "            criterion_key=\"dice\"\n",
    "        ),\n",
    "        CriterionCallback(\n",
    "            input_key=\"mask\",\n",
    "            prefix=\"loss_iou\",\n",
    "            criterion_key=\"iou\"\n",
    "        ),\n",
    "        CriterionCallback(\n",
    "            input_key=\"mask\",\n",
    "            prefix=\"loss_bce\",\n",
    "            criterion_key=\"bce\"\n",
    "        ),\n",
    "        \n",
    "        # And only then we aggregate everything into one loss.\n",
    "        CriterionAggregatorCallback(\n",
    "            prefix=\"loss\",\n",
    "            loss_aggregate_fn=\"weighted_sum\", # can be \"sum\", \"weighted_sum\" or \"mean\"\n",
    "            # because we want weighted sum, we need to add scale for each loss\n",
    "            loss_keys={\"loss_dice\": 1.0, \"loss_iou\": 1.0, \"loss_bce\": 0.8},\n",
    "        ),\n",
    "        \n",
    "        # metrics\n",
    "        DiceCallback(input_key=\"mask\"),\n",
    "        IouCallback(input_key=\"mask\"),\n",
    "    ],\n",
    "    # path to save logs\n",
    "    logdir=logdir,\n",
    "    \n",
    "    num_epochs=num_epochs,\n",
    "    \n",
    "    # save our best checkpoint by IoU metric\n",
    "    main_metric=\"iou\",\n",
    "    # IoU needs to be maximized.\n",
    "    minimize_metric=False,\n",
    "    \n",
    "    # for FP16. It uses the variable from the very first cell\n",
    "    fp16=fp16_params,\n",
    "    \n",
    "    # for external monitoring tools, like Alchemy\n",
    "    monitoring_params=monitoring_params,\n",
    "    \n",
    "    # prints train logs\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training analysis\n",
    "\n",
    "The `utils.plot_metrics` method reads tensorboard logs from the logdir and plots beautiful metrics with `plotly` package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# tensorboard should be enought, uncomment to check plotly version\n",
    "# it can take a while (colab issue)\n",
    "# utils.plot_metrics(\n",
    "#     logdir=logdir, \n",
    "#     # specify which metrics we want to plot\n",
    "#     metrics=[\"loss\", \"accuracy01\", \"auc/_mean\", \"f1_score\", \"_base/lr\"]\n",
    "# )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tRH8jhg3Q_zB"
   },
   "source": [
    "## Model inference\n",
    "\n",
    "Let's look at model predictions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "XyugaCieRnx2"
   },
   "outputs": [],
   "source": [
    "TEST_IMAGES = sorted(test_image_path.glob(\"*.jpg\"))\n",
    "\n",
    "# create test dataset\n",
    "test_dataset = SegmentationDataset(\n",
    "    TEST_IMAGES, \n",
    "    transforms=valid_transforms\n",
    ")\n",
    "\n",
    "num_workers: int = 4\n",
    "\n",
    "infer_loader = DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    "    num_workers=num_workers\n",
    ")\n",
    "\n",
    "# this get predictions for the whole loader\n",
    "predictions = runner.predict_loader(\n",
    "    model=model,\n",
    "    loader=infer_loader,\n",
    "    resume=f\"{logdir}/checkpoints/best.pth\",\n",
    "    verbose=False,\n",
    ")\n",
    "\n",
    "print(type(predictions))\n",
    "print(predictions.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "c57ZjVTxT-9s"
   },
   "outputs": [],
   "source": [
    "threshold = 0.5\n",
    "max_count = 5\n",
    "\n",
    "for i, (features, logits) in enumerate(zip(test_dataset, predictions)):\n",
    "    image = utils.tensor_to_ndimage(features[\"image\"])\n",
    "\n",
    "    mask_ = torch.from_numpy(logits[0]).sigmoid()\n",
    "    mask = utils.detach(mask_ > threshold).astype(\"float\")\n",
    "        \n",
    "    show_examples(name=\"\", image=image, mask=mask)\n",
    "    \n",
    "    if i >= max_count:\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "hBm6KJg8wTGM"
   },
   "source": [
    "## Model tracing\n",
    "\n",
    "Catalyst allows you to use Runner to make [tracing](https://pytorch.org/docs/stable/jit.html) models.\n",
    "\n",
    "> How to do this in the Config API, we wrote in [our blog (issue \\#2)](https://github.com/catalyst-team/catalyst-info#catalyst-info-2-tracing-with-torchjit)\n",
    "\n",
    "For this purpose it is necessary to pass in a method `trace ` model and a batch on which `predict_batch ` will be executed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(iter(loaders[\"valid\"]))\n",
    "# saves to `logdir` and returns a `ScriptModule` class\n",
    "runner.trace(model=model, batch=batch, logdir=logdir, fp16=is_fp16_used)\n",
    "\n",
    "!ls {logdir}/trace/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After this, you can easily load the model and predict anything!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from catalyst.dl.utils import trace\n",
    "\n",
    "if is_fp16_used:\n",
    "    model = trace.load_traced_model(\n",
    "        f\"{logdir}/trace/traced-forward-opt_O1.pth\", \n",
    "        device=\"cuda\", \n",
    "        opt_level=\"O1\"\n",
    "    )\n",
    "else:\n",
    "    model = trace.load_traced_model(\n",
    "        f\"{logdir}/trace/traced-forward.pth\", \n",
    "        device=\"cpu\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_input = batch[\"image\"].to(\"cuda\" if is_fp16_used else \"cpu\")\n",
    "model(model_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Advanced: Custom Callbacks\n",
    "\n",
    "Let's plot the heatmap of predicted masks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "from catalyst.dl import Callback, CallbackOrder, State\n",
    "\n",
    "\n",
    "class CustomInferCallback(Callback):\n",
    "    def __init__(self):\n",
    "        super().__init__(CallbackOrder.Internal)\n",
    "        self.heatmap = None\n",
    "        self.counter = 0\n",
    "\n",
    "    def on_loader_start(self, state: State):\n",
    "        self.predictions = None\n",
    "        self.counter = 0\n",
    "\n",
    "    def on_batch_end(self, state: State):\n",
    "        # data from the Dataloader\n",
    "        # image, mask = state.input[\"image\"], state.input[\"mask\"]\n",
    "        logits = state.output[\"logits\"]\n",
    "        probabilities = torch.sigmoid(logits)\n",
    "\n",
    "        self.heatmap = (\n",
    "            probabilities \n",
    "            if self.heatmap is None \n",
    "            else self.heatmap + probabilities\n",
    "        )\n",
    "        self.counter += len(probabilities)\n",
    "\n",
    "    def on_loader_end(self, state: State):\n",
    "        self.heatmap = self.heatmap.sum(axis=0)\n",
    "        self.heatmap /= self.counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "from catalyst.dl.callbacks import CheckpointCallback\n",
    "\n",
    "\n",
    "infer_loaders = {\"infer\": loaders[\"valid\"]}\n",
    "model = smp.FPN(encoder_name=\"resnext50_32x4d\", classes=1)\n",
    "\n",
    "device = utils.get_device()\n",
    "if is_fp16_used:\n",
    "    fp16_params = dict(opt_level=\"O1\") # params for FP16\n",
    "else:\n",
    "    fp16_params = None\n",
    "\n",
    "runner = SupervisedRunner(device=device, input_key=\"image\", input_target_key=\"mask\")\n",
    "runner.infer(\n",
    "    model=model,\n",
    "    loaders=infer_loaders,\n",
    "    callbacks=OrderedDict([\n",
    "        (\"loader\", CheckpointCallback(resume=f\"{logdir}/checkpoints/best.pth\")),\n",
    "        (\"infer\", CustomInferCallback())\n",
    "    ]),\n",
    "    fp16=fp16_params,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "%matplotlib inline \n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "heatmap = utils.detach(runner.callbacks[\"infer\"].heatmap[0])\n",
    "plt.figure(figsize=(20, 9))\n",
    "plt.imshow(heatmap, cmap=\"hot\", interpolation=\"nearest\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Advanced: test-time augmentations (TTA)\n",
    "\n",
    "There is [ttach](https://github.com/qubvel/ttach) is a new awesome library for test-time augmentation for segmentation or classification tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import ttach as tta\n",
    "\n",
    "# D4 makes horizontal and vertical flips + rotations for [0, 90, 180, 270] angels.\n",
    "# and then merges the result masks with merge_mode=\"mean\"\n",
    "tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode=\"mean\")\n",
    "\n",
    "tta_runner = SupervisedRunner(\n",
    "    model=tta_model,\n",
    "    device=utils.get_device(),\n",
    "    input_key=\"image\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "infer_loader = DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=1,\n",
    "    shuffle=False,\n",
    "    num_workers=num_workers\n",
    ")\n",
    "\n",
    "batch = next(iter(infer_loader))\n",
    "\n",
    "# predict_batch will automatically move the batch to the Runner's device\n",
    "tta_predictions = tta_runner.predict_batch(batch)\n",
    "\n",
    "# shape is `batch_size x channels x height x width`\n",
    "print(tta_predictions[\"logits\"].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see our mask after TTA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "threshold = 0.5\n",
    "\n",
    "image = utils.tensor_to_ndimage(batch[\"image\"][0])\n",
    "\n",
    "mask_ = tta_predictions[\"logits\"][0, 0].sigmoid()\n",
    "mask = utils.detach(mask_ > threshold).astype(\"float\")\n",
    "\n",
    "show_examples(name=\"\", image=image, mask=mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "name": "segmentation-tutorial.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}